#include <cstdio>
#include <cstdlib>
#include "dataset.h"

// File-local storage for the image file: four header words followed by the
// pixel data of 10000 images, each stored as 28 * 28 bytes.
static struct {
  int magicNumber;  // first header word of the image file
  int num;          // image count as stated in the header
  int width;        // image width as stated in the header
  int height;       // image height as stated in the header
  // One row per image; each row holds that image's 28 * 28 bytes.
  unsigned char data[10000][28 * 28];
} imageSet;
// File-local storage for the label file: two header words followed by one
// label byte per image (10000 in total).
static struct {
  int magicNumber;  // first header word of the label file
  int num;          // label count as stated in the header
  // data[i] is the label byte read for image i.
  unsigned char data[10000];
} labelSet;

// Locations of the image and label files opened by getDataset(). Both are
// relative paths, so they are resolved against the working directory of the
// running process. (The file names match the MNIST 10k test-set files.)
static const char imagePath[] = "data/t10k-images.idx3-ubyte";
static const char labelPath[] = "data/t10k-labels.idx1-ubyte";

// Decodes one 32-bit IDX header field from `fp` into `*field`.
// IDX headers store integers most-significant byte first, so the bytes are
// combined explicitly; reading them straight into an int yields byte-swapped
// values on little-endian hosts.
// Returns false if fewer than four bytes could be read.
static bool readIdxHeaderField(FILE *fp, int *field) {
  unsigned char raw[4];
  if (fread(raw, sizeof raw[0], sizeof raw, fp) != sizeof raw) {
    return false;
  }
  unsigned int value = raw[0];
  value = (value << 8) | raw[1];
  value = (value << 8) | raw[2];
  value = (value << 8) | raw[3];
  *field = (int)value;
  return true;
}

// Fills imageSet and labelSet from two already-open IDX streams and prints
// both headers to stdout. Returns false, after a message on stderr, when a
// header is short or does not describe data that fits the static buffers, or
// when the payload is truncated. The caller keeps ownership of both streams.
static bool readIdxStreams(FILE *fpImageSet, FILE *fpLabelSet) {
  const size_t bytesPerImage = sizeof imageSet.data[0];
  const size_t imageCapacity = sizeof imageSet.data / sizeof imageSet.data[0];
  const size_t labelCapacity = sizeof labelSet.data / sizeof labelSet.data[0];
  const int side = 28;  // each imageSet.data row is 28 * 28 bytes

  // Image header: magic number, image count, row count, column count.
  // IDX lists rows (height) before columns (width).
  if (!readIdxHeaderField(fpImageSet, &imageSet.magicNumber) ||
      !readIdxHeaderField(fpImageSet, &imageSet.num) ||
      !readIdxHeaderField(fpImageSet, &imageSet.height) ||
      !readIdxHeaderField(fpImageSet, &imageSet.width)) {
    fprintf(stderr, "truncated header in t10k-images.idx3-ubyte\n");
    return false;
  }
  printf("#### 图像信息 ####\n");
  printf("Magic Number = %0x\n", (unsigned int)imageSet.magicNumber);
  printf("图像总数 = %0x\n", (unsigned int)imageSet.num);
  printf("宽 = %0x\n", (unsigned int)imageSet.width);
  printf("高 = %0x\n", (unsigned int)imageSet.height);

  // Label header: magic number, label count.
  if (!readIdxHeaderField(fpLabelSet, &labelSet.magicNumber) ||
      !readIdxHeaderField(fpLabelSet, &labelSet.num)) {
    fprintf(stderr, "truncated header in t10k-labels.idx1-ubyte\n");
    return false;
  }
  printf("#### 标签信息 ####\n");
  printf("Magic Number = %0x\n", (unsigned int)labelSet.magicNumber);
  printf("总数 = %0x\n", (unsigned int)labelSet.num);

  // The buffers have a fixed shape, so refuse files whose headers describe a
  // different layout or fewer records than will be read below.
  if (imageSet.magicNumber != 0x00000803 || imageSet.width != side ||
      imageSet.height != side || imageSet.num < (int)imageCapacity) {
    fprintf(stderr, "unexpected header in t10k-images.idx3-ubyte\n");
    return false;
  }
  if (labelSet.magicNumber != 0x00000801 ||
      labelSet.num < (int)labelCapacity) {
    fprintf(stderr, "unexpected header in t10k-labels.idx1-ubyte\n");
    return false;
  }

  // Both payloads are contiguous byte arrays, so each is read in one call.
  if (fread(imageSet.data, bytesPerImage, imageCapacity, fpImageSet) !=
      imageCapacity) {
    fprintf(stderr, "truncated data in t10k-images.idx3-ubyte\n");
    return false;
  }
  if (fread(labelSet.data, sizeof labelSet.data[0], labelCapacity,
            fpLabelSet) != labelCapacity) {
    fprintf(stderr, "truncated data in t10k-labels.idx1-ubyte\n");
    return false;
  }
  return true;
}

// Loads the 10000 test images and their labels into file-local static
// buffers and hands out pointers to them.
//
// imageData: receives a pointer to 10000 consecutive images of 28 * 28 bytes.
// labelData: receives a pointer to 10000 label bytes; entry i belongs to
//            image i.
//
// The returned pointers refer to static storage: the caller must not free
// them, and a later call overwrites the contents.
//
// Returns true on success. On any failure (null argument, file cannot be
// opened, malformed or truncated file) a message is written to stderr, both
// outputs are set to nullptr and false is returned; callers must check the
// result before using the pointers.
bool getDataset(unsigned char **imageData, unsigned char **labelData) {
  if (imageData == nullptr || labelData == nullptr) {
    return false;
  }
  *imageData = nullptr;
  *labelData = nullptr;

  // The files are only read, so open them read-only ("rb"); "rb+" would fail
  // on read-only files or directories.
  FILE *fpImageSet = fopen(imagePath, "rb");
  if (fpImageSet == nullptr) {
    fprintf(stderr, "unable open t10k-images.idx3-ubyte\n");
    return false;
  }

  FILE *fpLabelSet = fopen(labelPath, "rb");
  if (fpLabelSet == nullptr) {
    fprintf(stderr, "unable open t10k-labels.idx1-ubyte\n");
    fclose(fpImageSet);
    return false;
  }

  const bool loaded = readIdxStreams(fpImageSet, fpLabelSet);

  // Everything needed now lives in the static buffers; release the streams
  // on success and failure alike.
  fclose(fpLabelSet);
  fclose(fpImageSet);

  if (!loaded) {
    return false;
  }

  *imageData = &imageSet.data[0][0];
  *labelData = labelSet.data;
  return true;
}

// Prints a 28x28 grayscale image to stdout, one text line per image row.
//
// image: 28 * 28 pixel bytes, stored row after row.
//
// Every pixel is written as its value padded to width 3, preceded by an
// "ESC[48;2;r;g;bm" background-colour escape with r = g = b = pixel value.
// Each line ends with "ESC[0m" to reset the terminal attributes.
void colorShell(unsigned char *image) {
  const int side = 28;
  const unsigned char *pixel = image;

  for (int row = 0; row < side; ++row) {
    // Walk the current row; `pixel` ends up on the first byte of the next.
    for (const unsigned char *rowEnd = pixel + side; pixel != rowEnd;
         ++pixel) {
      const int gray = *pixel;
      printf("\033[48;2;%d;%d;%dm", gray, gray, gray);
      printf("%3d", gray);
    }
    printf("\033[0m\n");
  }
}
